import cv2
import numpy as np
import scipy.io
import os
import pickle
import torch
import tqdm

import torch.utils.data
from PIL import Image
from copy import deepcopy
from torch.utils.data import DataLoader
import torch.multiprocessing

from vision_utils.my_utils import visualize_boxes_yolo, visualize_boxes
from vision_utils.yolov3.utils.general import xyxy2xywh, xywhn2xyxy
from vision_utils.yolov3.utils.datasets import letterbox
from vision_utils import vis_utils, my_utils
import vision_utils.transforms as T

def _numbered(prefix, label, numbers=range(1, 10)):
	"""Map every ``<prefix>_<n>`` (``n`` in ``numbers``) to ``label``."""
	return {f"{prefix}_{n}": label for n in numbers}


# Raw PascalPart object / part names -> merged class names. Left/right,
# front/back and numbered instances of the same part collapse onto one class
# (e.g. 'lfoot' and 'rfoot' -> 'foot', 'window_1'..'window_29' -> 'window').
# Insertion order matches the original hand-written table.
mappings = {
	# object classes
	'0-background': '0-background',
	'aeroplane': 'aeroplane', 'bicycle': 'bicycle', 'bird': 'bird', 'boat': 'boat',
	'bottle': 'bottle', 'bus': 'bus', 'car': 'car', 'cat': 'cat', 'chair': 'chair',
	'cow': 'cow', 'diningtable': 'table', 'table': 'table', 'dog': 'dog',
	'horse': 'horse', 'motorbike': 'motorbike', 'person': 'person',
	'pottedplant': 'pottedplant', 'sheep': 'sheep', 'sofa': 'sofa',
	'train': 'train', 'tvmonitor': 'tvmonitor',
	# parts
	'lfoot': 'foot', 'hair': 'hair', 'fwheel': 'wheel', 'lhand': 'hand',
	'rfoot': 'foot', 'llleg': 'leg', 'chainwheel': 'chainwheel', 'ruarm': 'arm',
	'rlarm': 'arm', 'rlleg': 'leg', 'rhand': 'hand', 'llarm': 'arm',
	'luarm': 'arm', 'luleg': 'leg', 'saddle': 'saddle', 'bwheel': 'wheel',
	'handlebar': 'handlebar', 'ruleg': 'leg', 'head': 'head', 'torso': 'torso',
	'beak': 'beak', 'rleg': 'leg', 'tail': 'tail', 'frontside': 'frontside',
	'rightside': 'rightside', 'leftside': 'leftside',
	**_numbered('window', 'window', range(1, 30)),
	**_numbered('headlight', 'headlight'),
	**_numbered('wheel', 'wheel'),
	**_numbered('door', 'door', range(1, 5)),
	'fliplate': 'plate', 'bliplate': 'plate',
	'rightmirror': 'mirror', 'leftmirror': 'mirror',
	'lear': 'ear', 'rear': 'ear', 'leye': 'eye', 'reye': 'eye',
	'lebrow': 'ebrow', 'rebrow': 'ebrow',
	'mouth': 'mouth', 'body': 'body', 'nose': 'nose', 'neck': 'neck',
	'pot': 'pot', 'plant': 'plant', 'cap': 'cap',
	'lwing': 'wing', 'rwing': 'wing', 'muzzle': 'muzzle',
	'lfpa': 'paw', 'rfpa': 'paw', 'rbpa': 'paw', 'lbpa': 'paw',
	'lfleg': 'leg', 'rfleg': 'leg', 'rbleg': 'leg', 'lbleg': 'leg', 'lleg': 'leg',
	'screen': 'screen',
	**_numbered('coach', 'coach'),
	'lflleg': 'leg', 'lfuleg': 'leg', 'rflleg': 'leg', 'lblleg': 'leg',
	'lbuleg': 'leg', 'rfuleg': 'leg', 'rbuleg': 'leg', 'rblleg': 'leg',
	'lfho': 'hoof', 'rfho': 'hoof', 'lbho': 'hoof', 'rbho': 'hoof',
	'stern': 'stern',
	**_numbered('engine', 'engine'),
	**_numbered('cleftside', 'leftside'),
	**_numbered('crightside', 'rightside'),
	**_numbered('cfrontside', 'frontside'),
	'hfrontside': 'frontside', 'hleftside': 'leftside', 'hrightside': 'rightside',
	'hroofside': 'roofside', 'hbackside': 'backside', 'backside': 'backside',
	'croofside_1': 'roofside',
	'roofside': 'roofside',
	**_numbered('croofside', 'roofside', range(2, 10)),
	'lhorn': 'horn', 'rhorn': 'horn',
	**_numbered('cbackside', 'backside'),
	# object-specific names for part names shared by several objects
	'train_head': 'train_head',
	'aeroplane_body': 'aeroplane_body',
	'bottle_body': 'bottle_body',
}

# Sorted unique class names; '0-background' sorts first, so it gets id 0.
name_list = np.unique(list(mappings.values()))
name_ids = dict(zip(name_list, range(len(name_list))))

data_folder = os.path.join("data", "PascalPart", "Annotations", "Annotations_Part")
mask_folder = os.path.join("data", "PascalPart", "Masks")
imgs_folder = os.path.join("data", "PascalPart", "Images")

img_list = os.listdir(imgs_folder)
mask_list = os.listdir(mask_folder)

# Indices into the sorted image listing of PascalPartDataset; these samples are
# removed (image and target) when the dataset is constructed.
# NOTE(review): why these targets are "bad" is not recorded here — presumably
# they failed box checks such as check_box_consistency(); confirm before editing.
bad_targets_idx = [
	125, 154, 181, 287, 396, 400, 447, 494, 935, 1124,
	1249, 1319, 2028, 2046, 2374, 2619, 2661, 2722, 2924, 2927,
	2974, 3107, 3457, 3558, 3680, 3790, 4192, 4335, 4351, 4530,
	4700, 4860, 4902, 5016, 5017, 5250, 5488, 5588, 5789, 6039,
	6066, 6101, 6108, 6648, 6773, 6792, 6874, 7142, 7327, 7346,
	7355, 7449, 7772, 7964, 8231, 8248, 8450, 8755, 8848, 9111,
	9535, 10036,
]


def create_pascal_part_dataset(override=False):
	"""Convert the PascalPart ``.mat`` annotations into pickled mask stacks.

	For every annotation file in ``data_folder`` a pickle named
	``<image name>_mask.pckl`` is written to ``mask_folder``. It contains an
	array with one channel per annotated object and per part: the instance mask
	from the annotation multiplied by the class id (``name_ids``) of its
	merged class name (``mappings``).

	Args:
		override: if ``mask_folder`` already exists the conversion is skipped
			unless this is True.
	"""
	if os.path.exists(mask_folder):
		if not override:
			return
	else:
		os.makedirs(mask_folder)
	print("Creating Pascalpart dataset")

	# part names used by several objects get an object-specific name
	renamed_parts = {
		("train", "head"): "train_head",
		("aeroplane", "body"): "aeroplane_body",
		("bottle", "body"): "bottle_body",
	}

	# list the folder once so the progress bar and the loop agree
	filenames = os.listdir(data_folder)
	for filename in tqdm.tqdm(filenames):
		anno = scipy.io.loadmat(os.path.join(data_folder, filename))['anno'][0][0]
		img_name = anno[0].item()
		channels = []
		for obj in anno[1][0]:
			object_name = mappings[obj[0][0]]
			channels.append(obj[2] * name_ids[object_name])
			part_list = obj[3]
			if len(part_list) == 0:
				continue
			for part in part_list[0]:
				part_name = part[0].item()
				part_name = renamed_parts.get((object_name, part_name), part_name)
				channels.append(part[1] * name_ids[mappings[part_name]])

		# np.stack cannot build an array from zero channels
		if not channels:
			print(f"No annotated objects for {img_name}, skipping")
			continue

		out_path = os.path.join(mask_folder, img_name + "_mask.pckl")
		with open(out_path, "wb") as f:
			pickle.dump(np.stack(channels), f)


def delete_not_annotated_images(imgs_dir=None, masks_dir=None):
	"""Delete every image that has no corresponding mask file.

	An image ``<stem>.<ext>`` is annotated when the mask folder holds a file
	whose name without extension is ``<stem>`` or ``<stem>_mask`` (the name
	written by create_pascal_part_dataset). Both folders are listed when the
	function is called, so masks created after import are taken into account.

	Args:
		imgs_dir: image folder; defaults to the module-level ``imgs_folder``.
		masks_dir: mask folder; defaults to the module-level ``mask_folder``.

	Returns:
		The number of deleted images.
	"""
	imgs_dir = imgs_folder if imgs_dir is None else imgs_dir
	masks_dir = mask_folder if masks_dir is None else masks_dir

	# exact stem lookup: O(1) per image, and a stem that is only a prefix of
	# another image's mask name no longer counts as annotated
	annotated = set()
	for name in os.listdir(masks_dir):
		stem = os.path.splitext(name)[0]
		if stem.endswith("_mask"):
			stem = stem[:-len("_mask")]
		annotated.add(stem)

	deleted = 0
	for image in os.listdir(imgs_dir):
		stem = image.split(".")[0]
		if stem in annotated:
			continue
		print(stem)
		os.remove(os.path.join(imgs_dir, image))
		deleted += 1
	print(f"{deleted} deleted images")
	return deleted


def check_img_mask_consistency():
	i, j = 0, 0
	for image, mask in zip(img_list, mask_list):
		if image.split(".")[0] not in mask:
			print(image, mask)
			i += 1
		else:
			j += 1
	print(f"{j} Consistent mask-image")
	print(f"{i} Inconsistent mask-image")


def check_box_consistency():
	"""Assert that every cached PascalPart target contains only valid boxes.

	Raises:
		AssertionError: if a box has area below 0.1 or a non-positive width or
			height.
	"""
	dataset = PascalPartDataset('data/PascalPart',
								my_utils.get_transform(train=True), load=True)
	# iterate (and size the progress bar by) the targets actually checked
	for target in tqdm.tqdm(dataset.targets):
		# reshape keeps an empty box list indexable as (0, 4)
		boxes = torch.as_tensor(target['boxes']).reshape(-1, 4)
		widths = boxes[:, 2] - boxes[:, 0]
		heights = boxes[:, 3] - boxes[:, 1]
		area = heights * widths
		small_areas = area < 0.1
		assert not small_areas.any(), f"Error in bounding box area, {area[small_areas]} {boxes[small_areas]}"
		degenerate_boxes = (heights <= 0) | (widths <= 0)
		assert not degenerate_boxes.any(), f"Degenerated box {boxes[degenerate_boxes, :]}"


class PascalPartDataset(torch.utils.data.Dataset):
	"""Detection dataset over PascalPart images and their part masks.

	``root`` must contain an ``Images`` folder and a ``Masks`` folder holding
	the pickled mask stacks written by create_pascal_part_dataset(). Items are
	``(img, target)`` where ``target`` is a dict with ``boxes`` (xyxy, float32),
	``labels`` (int64 class ids from ``name_ids``), ``image_id``, ``area`` and
	``iscrowd``. Samples listed in ``bad_targets_idx`` are removed.

	Args:
		root: dataset root folder.
		transforms: callable ``(img, target) -> (img, target)`` applied to each item.
		load: if True, load targets cached in ``root/target_file``, or build and
			cache them when that file does not exist. Otherwise targets are
			computed from the masks on every access.
		target_file: file name of the cached targets inside ``root``.
	"""

	def __init__(self, root, transforms=None, load=True, target_file="targets.npy"):
		self.root = root
		self.transforms = transforms
		# sort both listings so that image i and mask i describe the same sample
		self.imgs = sorted(os.listdir(os.path.join(root, "Images")))
		self.masks = sorted(os.listdir(os.path.join(root, "Masks")))
		self.targets = [None] * len(self)
		self.target_file = target_file

		if load:
			target_path = os.path.join(root, self.target_file)
			if os.path.isfile(target_path):
				self.targets = np.load(target_path, allow_pickle=True)
				print("Targets loaded")
			else:
				self.build_targets()

		# Drop the known-bad samples (indices refer to the sorted listing).
		# Masks are dropped together with images; otherwise every mask after the
		# first removed index would be paired with the wrong image when targets
		# are built lazily in __getitem__.
		for idx in sorted(bad_targets_idx, reverse=True):
			self.imgs.pop(idx)
			if idx < len(self.masks):
				self.masks.pop(idx)
		self.targets = np.delete(self.targets, bad_targets_idx)

	def __getitem__(self, idx):
		"""Return ``(img, target)`` for sample ``idx``, after ``self.transforms``."""
		image = self.imgs[idx]
		img_path = os.path.join(self.root, "Images", image)
		try:
			img = Image.open(img_path).convert("RGB")
		except OSError:
			print(image)
			raise

		if self.targets[idx] is None:
			target = self._target_from_mask(idx)
		else:
			# Hand a copy to the transforms: they may modify the target in place,
			# which would corrupt the cached target for later epochs (the YOLO
			# subclass deep-copies its targets as well).
			target = deepcopy(self.targets[idx])

		if self.transforms is not None:
			img, target = self.transforms(img, target)

		return img, target

	def __len__(self):
		return len(self.imgs)

	def _target_from_mask(self, idx):
		"""Build the detection target of sample ``idx`` from its mask stack.

		Each channel of the stack is one instance whose pixels hold its class id
		(as written by create_pascal_part_dataset). Channels that are empty or
		carry more than one non-zero id are skipped together with their box, so
		boxes and labels stay aligned; boxes with area below 0.1 are dropped
		along with their label, area and iscrowd entries.

		Raises:
			ValueError: if the mask file name does not match the image name.
		"""
		image, mask_name = self.imgs[idx], self.masks[idx]
		if image.split(".")[0] not in mask_name:
			raise ValueError(f"Error in loading image {image} or mask {mask_name}")
		masks = np.load(os.path.join(self.root, "Masks", mask_name), allow_pickle=True)
		# instances are encoded as different channels
		num_objs = masks.shape[0]

		boxes, labels = [], []
		for channel in masks:
			obj_ids = np.unique(channel)
			obj_ids = obj_ids[obj_ids != 0]
			if obj_ids.size != 1:
				print(f"Error in getting label {obj_ids}")
				continue
			ys, xs = np.nonzero(channel)
			boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])
			labels.append(obj_ids.item())

		# reshape keeps an empty box list indexable as (0, 4)
		boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
		labels = torch.as_tensor(labels, dtype=torch.int64)
		area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
		keep = area >= 0.1
		boxes, labels, area = boxes[keep], labels[keep], area[keep]

		# check for crowded instances (not counted in test evaluation)
		iscrowd = torch.full((len(boxes),), int(num_objs > 50), dtype=torch.int64)

		return {
			"boxes": boxes,
			"labels": labels,
			"image_id": torch.tensor([idx]),
			"area": area,
			"iscrowd": iscrowd,
		}

	def build_targets(self):
		"""Compute every target and cache them in ``root/target_file``.

		Targets are computed with ``my_utils.get_transform(train=False)`` instead
		of the dataset's own transforms (no data augmentation); the original
		transforms are restored afterwards, even if building fails.
		"""
		print("Building dataset targets")
		transforms = self.transforms
		self.transforms = my_utils.get_transform(train=False)
		try:
			for i in tqdm.trange(len(self)):
				# call the base implementation explicitly so subclasses that
				# override __getitem__ still get plain (img, target) pairs
				self.targets[i] = PascalPartDataset.__getitem__(self, i)[1]
			np.save(os.path.join(self.root, self.target_file),
					self.targets, allow_pickle=True)
		finally:
			self.transforms = transforms


class YoloPascalPartDataset(PascalPartDataset):
	"""PascalPartDataset variant producing YOLO-style samples.

	Items are ``(image_t, targets, shapes, idx)``:
	- ``image_t``: float tensor, RGB, channels first, letterboxed to
	  ``image_size`` x ``image_size``;
	- ``targets``: ``(n, 6)`` tensor of
	  ``[image index (0), class (label - 1), x_center, y_center, width, height]``
	  with coordinates normalised to [0, 1];
	- ``shapes``: original size and resize/pad factors for rescaling predictions;
	- ``idx``: the sample index.

	NOTE(review): __getitem__ reads ``self.targets`` directly, so targets must
	be in memory (construct with ``load=True``).
	"""

	def __init__(self, *args, image_size=416, **kwargs):
		self.verbose = False  # when True, show the boxes before and after letterboxing
		self.image_size = image_size
		super().__init__(*args, **kwargs)

	def __getitem__(self, idx):
		# copy: the boxes are rescaled in place below
		target = deepcopy(self.targets[idx])

		# return an image already with max dimension equal to image_size
		image, (h0, w0), (h, w) = self.load_image(idx)
		boxes = target['boxes']
		boxes[:, [1, 3]] = boxes[:, [1, 3]] / h0 * h
		boxes[:, [0, 2]] = boxes[:, [0, 2]] / w0 * w
		if self.verbose:
			visualize_boxes_yolo(torch.as_tensor(image).int(), target,
								 name_list, show=True, to_permute=False)
		# add padding and rescale the image (?) to fit (image_size, image_size)
		image, ratio, pad = letterbox(image, self.image_size,
									  auto=False, scaleup=False)
		boxes[:, [1, 3]] = boxes[:, [1, 3]] * ratio[1] + pad[1]
		boxes[:, [0, 2]] = boxes[:, [0, 2]] * ratio[0] + pad[0]
		if self.verbose:
			visualize_boxes_yolo(torch.as_tensor(image).int(), target,
								 name_list, show=True, to_permute=False)

		# convert to yolo notation (center of the box, width and height) and normalize
		boxes = xyxy2xywh(boxes)  # convert xyxy to xywh
		boxes[:, [1, 3]] /= image.shape[0]  # normalized height 0-1
		boxes[:, [0, 2]] /= image.shape[1]  # normalized width 0-1

		labels = target['labels'] - 1  # YOLO targets start from 0!
		img_idx = torch.zeros_like(labels)
		targets = torch.cat([img_idx.unsqueeze(dim=1),
							 labels.unsqueeze(dim=1),
							 boxes], dim=1)
		assert targets.shape[1] == 6, 'labels require 5 columns each'
		assert (targets >= 0).all(), 'negative labels'
		assert (targets[:, 2:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'

		image = image[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
		image = np.ascontiguousarray(image)
		image_t = torch.from_numpy(image).float()

		shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

		return image_t, targets, shapes, idx

	def load_image(self, index):
		"""Load image ``index`` with OpenCV and resize its longer side to ``image_size``.

		Returns:
			``(img, (h0, w0), (h, w))``: the BGR image, its original height/width
			and its resized height/width.
		"""
		img_path = os.path.join(self.root, "Images", self.imgs[index])
		img = cv2.imread(img_path)  # BGR
		assert img is not None, 'Image Not Found ' + img_path
		h0, w0 = img.shape[:2]  # orig hw
		r = self.image_size / max(h0, w0)  # ratio
		if r != 1:  # if sizes are not equal
			img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
							 interpolation=cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR)
		return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized

	def get_coco_compatible_dataset(self):
		"""Return a plain PascalPartDataset over the same root, transforms and target file.

		Cached targets are loaded only if this dataset has targets in memory.
		"""
		load = len(self.targets) > 0 and self.targets[0] is not None
		# propagate target_file: otherwise a custom cache would be ignored and the
		# default "targets.npy" loaded (or rebuilt) instead
		return PascalPartDataset(self.root, self.transforms, load,
								 target_file=self.target_file)


def check_dataset(root='data/PascalPart', batch_size=2, num_workers=8):
	"""Iterate once over a shuffled training DataLoader to surface loading errors.

	Args:
		root: dataset root passed to PascalPartDataset.
		batch_size: DataLoader batch size.
		num_workers: number of DataLoader worker processes.
	"""
	dataset = PascalPartDataset(root, my_utils.get_transform(train=True))
	data_loader = DataLoader(
		dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
		collate_fn=vis_utils.collate_fn)
	# iterating through tqdm closes the progress bar when the loop ends
	for _ in tqdm.tqdm(data_loader):
		pass


def check_dataset_loading():
	"""Instantiate PascalPartDataset with ``load=True`` (builds the target cache if missing)."""
	train_transforms = my_utils.get_transform(train=True)
	PascalPartDataset('data/PascalPart', train_transforms, load=True)


def create_dog_vs_person_dataset():
	"""Save box visualisations of the PascalPart samples that contain a dog.

	Every sample with at least one "dog" label and fewer than 10 labels in
	total is drawn with ``visualize_boxes`` and written to
	``data/DogvsPerson/Images_prediction/<n>.png`` (``n`` counts from 0). The
	``Images`` and ``Images_prediction`` folders are created if missing, and
	``new_targets`` is saved to ``data/DogvsPerson/<target_file>``.

	TODO: the dog-vs-person subset itself (copying dog/person images into
	``Images`` and keeping only boxes whose class is in ``dog_parts`` or
	``person_parts``) is not implemented yet, so the saved targets are empty.
	"""
	# part vocabularies of the planned subset (not used yet, see TODO)
	person_parts = ['Person', 'Torso', 'Leg', 'Head', 'Ear', 'Eye', 'Neck', 'Nose', 'Mouth', 'Foot', 'Ebrow', 'Hair', 'Arm', 'Hand']
	dog_parts = ['Dog', 'Torso', 'Leg', 'Head', 'Ear', 'Eye', 'Neck', 'Nose', 'Muzzle', 'Paw', 'Tail']

	dataset = PascalPartDataset('data/PascalPart',
								my_utils.get_transform(train=True), load=True)
	new_dir = os.path.join('data', 'DogvsPerson')
	new_img_dir = os.path.join(new_dir, "Images")
	new_img_dir_boxes = os.path.join(new_dir, "Images_prediction")
	# makedirs also creates new_dir (and data/) when missing
	for folder in (new_img_dir, new_img_dir_boxes):
		os.makedirs(folder, exist_ok=True)

	new_targets = []
	saved = 0
	for image, target in tqdm.tqdm(dataset, total=len(dataset)):
		labels = target['labels']
		if len(labels) >= 10:
			continue
		if any(name_list[idx] == "dog" for idx in labels):
			save_path = os.path.join(new_img_dir_boxes, f"{saved}.png")
			visualize_boxes(image, target, name_list, show=False, add_score=False, save_path=save_path)
			saved += 1
	np.save(os.path.join(new_dir, dataset.target_file),
			new_targets, allow_pickle=True)


if __name__ == "__main__":
	# One-off maintenance steps; uncomment the ones to run:
	#   print(name_list)
	#   print(name_ids)
	#   create_pascal_part_dataset(override=True)
	#   print(os.listdir(mask_folder))
	#   delete_not_annotated_images()
	#   check_img_mask_consistency()
	#   check_dataset_loading()
	create_dog_vs_person_dataset()

	# Timing of the first batch with lazily built (load=False) vs cached
	# (load=True) targets:
	# import time
	# for load in (False, True):
	# 	dataset = PascalPartDataset('data/PascalPart',
	# 								my_utils.get_transform(train=True), load=load)
	# 	data_loader = DataLoader(dataset, batch_size=2, num_workers=0,
	# 							 shuffle=True, collate_fn=vis_utils.collate_fn)
	# 	t = time.time()
	# 	_, _ = next(iter(data_loader))
	# 	print(f"Elapsed {time.time() - t} s")
